In [1]:
import warnings
warnings.filterwarnings("ignore")
In [2]:
from IPython.display import Image, display
from IPython.core.display import HTML 
In [3]:
import numpy as np
In [4]:
import pandas as pd
In [5]:
import seaborn as sns
In [6]:
import matplotlib.pyplot as plt
%matplotlib inline
In [7]:
import random, tqdm
In [8]:
import os, cv2
In [9]:
import torch
In [10]:
import torch.nn as nn
In [11]:
from torch.utils.data import DataLoader
In [12]:
import albumentations as album
In [13]:
import segmentation_models_pytorch as smp
In [14]:
from segmentation_models_pytorch import utils
In [15]:
# Optional prefix placed in front of the dataset folder name; '' resolves
# CVCDataset/ relative to the current working directory.
cpath = ''
In [16]:
DATA_DIR = f'{cpath}CVCDataset/'
In [17]:
# metadata.csv must provide at least frame_id, png_image_path and png_mask_path;
# the two path columns are relative to DATA_DIR (they are joined onto it below).
metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
In [18]:
# Keep only the columns used by the rest of the notebook.
metadata_df = metadata_df[['frame_id', 'png_image_path', 'png_mask_path']]
In [19]:
# Prefix DATA_DIR so the paths can be opened directly by cv2.imread in the dataset class.
# NOTE: not idempotent -- re-running this cell prefixes DATA_DIR a second time.
metadata_df['png_image_path'] = metadata_df['png_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
In [20]:
metadata_df['png_mask_path'] = metadata_df['png_mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
In [21]:
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)
In [22]:
valid_df = metadata_df.sample(frac=0.1, random_state=42)
In [23]:
train_df = metadata_df.drop(valid_df.index)
In [24]:
len(train_df), len(valid_df)
Out[24]:
(551, 61)
In [25]:
# class_dict.csv maps each class name (column 'class_names') to its label colour
# (columns 'r', 'g', 'b'); row order here is the order used for lookups below.
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
In [26]:
class_names = class_dict['class_names'].tolist()
In [27]:
# One [r, g, b] list per class, aligned with class_names.
class_rgb_values = class_dict[['r','g','b']].values.tolist()
In [28]:
print('All dataset classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)
All dataset classes and their corresponding RGB values in labels:
Class Names:  ['background', 'polyp']
Class RGB values:  [[0, 0, 0], [255, 255, 255]]
In [29]:
# The order of this list defines the channel order of the one-hot masks and of the
# model's output (len(CLASSES) channels; 'polyp' is looked up by index later).
select_classes = ['background', 'polyp']
In [30]:
# Names are lower-cased before lookup, so class_dict names must be lower case;
# list.index raises ValueError for a class not present in class_dict.csv.
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
In [31]:
# RGB colour of each selected class, in select_classes order.
select_class_rgb_values =  np.array(class_rgb_values)[select_class_indices]
In [32]:
# Report the *selected* classes (the ones the masks and model actually use),
# not the full class_dict listing printed earlier.
print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', select_classes)
print('Class RGB values: ', select_class_rgb_values.tolist())
Selected classes and their corresponding RGB values in labels:
Class Names:  ['background', 'polyp']
Class RGB values:  [[0, 0, 0], [255, 255, 255]]
In [33]:
def visualize(**images):
    """Display the given images side by side in a single row.

    Each keyword argument becomes one panel; its name, with underscores turned
    into spaces and title-cased, is used as the panel title. Axis ticks are
    hidden, and the figure is shown immediately via plt.show().
    """
    panel_count = len(images)
    plt.figure(figsize=(20, 8))
    for position, (label, picture) in enumerate(images.items(), start=1):
        ax = plt.subplot(1, panel_count, position)
        ax.set_xticks([])
        ax.set_yticks([])
        pretty_label = label.replace('_', ' ').title()
        ax.set_title(pretty_label, fontsize=20)
        # pyplot-level imshow so the image also becomes pyplot's current mappable.
        plt.imshow(picture)
    plt.show()
In [34]:
def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of boolean class masks.

    Args:
        label: array whose last axis holds colour components (e.g. H x W x 3).
        label_values: sequence of colours, one per class; output channel k
            corresponds to label_values[k].

    Returns:
        Boolean array of shape label.shape[:-1] + (len(label_values),). Channel k
        is True where every component of the pixel equals label_values[k]; pixels
        matching no colour are False in every channel.
    """
    channel_masks = [
        np.equal(label, colour).all(axis=-1)
        for colour in label_values
    ]
    return np.stack(channel_masks, axis=-1)
    
def reverse_one_hot(image):
    """Collapse a channels-last one-hot mask or score map into class indices.

    Returns, for every position, the index of the largest value along the last
    axis (np.argmax semantics: ties resolve to the lowest index).
    """
    return np.argmax(image, axis=-1)

def colour_code_segmentation(image, label_values):
    """Replace every class index in `image` with that class's colour.

    `image` is cast to int before lookup, so float index maps are accepted.
    The result has shape image.shape + np.shape(label_values)[1:] and the dtype
    numpy infers for `label_values`.
    """
    palette = np.array(label_values)
    indices = image.astype(int)
    return palette[indices]
In [35]:
class EndoscopyDataset(torch.utils.data.Dataset):
    """Image / one-hot-mask pairs read from the paths listed in a DataFrame.

    Args:
        df: DataFrame with 'png_image_path' and 'png_mask_path' columns.
        class_rgb_values: sequence of RGB colours, one per class; masks are
            one-hot encoded against it (channel order follows this sequence).
            Must be provided before items are accessed.
        augmentation: optional callable taking image=..., mask=... and returning
            a dict with 'image' and 'mask' (albumentations-style).
        preprocessing: optional callable with the same contract, applied after
            augmentation.

    __getitem__ returns (image, mask). Without preprocessing, image is an RGB
    uint8 array and mask a float array with one channel per class.
    """

    def __init__(
            self,
            df,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.image_paths = df['png_image_path'].tolist()
        self.mask_paths = df['png_mask_path'].tolist()

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _read_rgb(path):
        """Read `path` with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        bgr = cv2.imread(path)
        # cv2.imread returns None (it does not raise) for missing/unreadable files;
        # without this check the failure surfaces as an opaque cvtColor assertion
        # that does not name the offending file.
        if bgr is None:
            raise FileNotFoundError(f'Could not read image file: {path}')
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        return len(self.image_paths)
In [36]:
dataset = EndoscopyDataset(train_df, class_rgb_values=select_class_rgb_values)
In [37]:
random_idx = random.randint(0, len(dataset)-1)
In [38]:
image, mask = dataset[2]
In [39]:
visualize(
    original_image = image,
    ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask = reverse_one_hot(mask)
)
In [40]:
def get_training_augmentation():
    """Build the training-time augmentation pipeline.

    Currently a single random horizontal flip (probability 0.5), wrapped in an
    albumentations Compose so it can be called with image=..., mask=....
    """
    return album.Compose([album.HorizontalFlip(p=0.5)])
In [41]:
def get_validation_augmentation(min_height=288, min_width=384):
    """Build the validation/test-time pipeline: pad image and mask to a minimum size.

    Inputs already at least min_height x min_width pass through unchanged; nothing
    is resized or cropped. The defaults keep the original 288 x 384 target.
    NOTE(review): both defaults are multiples of 32, presumably so the encoder's
    downsampling divides evenly -- confirm before changing them.

    Args:
        min_height: minimum output height in pixels.
        min_width: minimum output width in pixels.
    """
    test_transform = [
        # p=1.0 always applies the transform; it replaces `always_apply=True`,
        # which newer albumentations releases deprecate/remove.
        # cv2.BORDER_CONSTANT (== 0) pads with a constant fill value.
        album.PadIfNeeded(
            min_height=min_height,
            min_width=min_width,
            border_mode=cv2.BORDER_CONSTANT,
            p=1.0,
        ),
    ]
    return album.Compose(test_transform)
In [42]:
def to_tensor(x, **kwargs):
    """Convert a channels-last (H, W, C) array to channels-first float32 (C, H, W).

    Extra keyword arguments are accepted and ignored, so the function can be used
    as an albumentations ``Lambda`` target.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')
In [43]:
def get_preprocessing(preprocessing_fn=None):
    """Build the preprocessing pipeline applied after augmentation.

    Args:
        preprocessing_fn: optional image-only transform (e.g. the encoder's input
            normalisation). It is skipped when falsy.

    Returns:
        An albumentations Compose that applies preprocessing_fn to the image (if
        given), then converts both image and mask to channels-first float32 with
        to_tensor.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps += [album.Lambda(image=to_tensor, mask=to_tensor)]
    return album.Compose(steps)
In [44]:
augmented_dataset = EndoscopyDataset(
    train_df, 
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)
In [45]:
random_idx = random.randint(0, len(augmented_dataset)-1)
In [46]:
for idx in range(3):
    image, mask = augmented_dataset[idx]
    visualize(
        original_image = image,
        ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
        one_hot_encoded_mask = reverse_one_hot(mask)
    )
In [47]:
# Encoder backbone and its pretrained weights (the 'imagenet' weights are
# downloaded on first use, as the output below shows).
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = select_classes
# Applied to the model output. NOTE(review): sigmoid scores each channel
# independently although background/polyp are mutually exclusive; 'softmax2d'
# may have been intended -- confirm before retraining.
ACTIVATION = 'sigmoid'
In [48]:
# One output channel per entry of CLASSES, in select_classes order.
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to C:\Users\basan/.cache\torch\hub\checkpoints\resnet50-19c8e357.pth
100%|█████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:26<00:00, 3.89MB/s]
In [49]:
# Input normalisation matching the chosen encoder weights; applied to images only.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
In [50]:
# Training split: random augmentation, then normalisation + channels-first conversion.
train_dataset = EndoscopyDataset(
    train_df, 
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)
In [51]:
# Validation split: deterministic padding only, same preprocessing.
valid_dataset = EndoscopyDataset(
    valid_df, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)
In [52]:
# NOTE(review): on Windows, DataLoader workers are spawned processes that may fail to
# unpickle classes defined inside a notebook; use num_workers=0 if iteration errors or hangs.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
In [53]:
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False, num_workers=4)
In [54]:
TRAINING = True
EPOCHS = 20
DEVICE = "cpu"
In [55]:
loss = smp.utils.losses.DiceLoss()
In [56]:
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]
In [57]:
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.00008),
])
In [58]:
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=5e-5,
)
In [59]:
if os.path.exists('Polyp-CVC-SG_model.pth'):
    model = torch.load('Polyp-CVC-SG_model.pth', map_location=DEVICE)
In [60]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)
In [61]:
valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)
In [62]:
can_train = False
In [63]:
%%time

if can_train:

    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []

    for i in range(0, EPOCHS):

        # Perform training & validation
        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Save model if a better val IoU score is obtained
        if best_iou_score < valid_logs['iou_score']:
            best_iou_score = valid_logs['iou_score']
            torch.save(model, './Polyp-CVC-SG_model.pth')
            print('Model saved!')
    train_logs_df = pd.DataFrame(train_logs_list)
    valid_logs_df = pd.DataFrame(valid_logs_list)
    train_logs_df.T
CPU times: total: 0 ns
Wall time: 0 ns
In [64]:
CHECKPOINT_PATH = './Polyp-CVC-SG_model.pth'
# Fail here with a clear message rather than a NameError on `best_model` in later cells.
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(
        f'{CHECKPOINT_PATH} not found; set can_train = True and run the training cell first.'
    )
# NOTE: torch.load unpickles the full model object (can execute arbitrary code) --
# only load checkpoints you trust.
best_model = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# Inference mode for batch-norm/dropout layers in the prediction cells below.
best_model.eval()
print('Loaded DeepLabV3+ model from checkpoint.')
Loaded UNet model from this run.
In [65]:
# NOTE: built from valid_df -- the same frames used to select the checkpoint in the
# training loop -- so "test" metrics here are validation metrics and likely optimistic.
test_dataset = EndoscopyDataset(
    valid_df, 
    augmentation=get_validation_augmentation(), 
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)
In [66]:
# Default DataLoader settings: batch_size=1, no shuffling.
test_dataloader = DataLoader(test_dataset)
In [67]:
# Same frames and order as test_dataset but without augmentation/preprocessing, so
# items are raw RGB images (used for display and as crop targets in the prediction loop).
test_dataset_vis = EndoscopyDataset(
    valid_df,
    class_rgb_values=select_class_rgb_values,
)
In [68]:
random_idx = random.randint(0, len(test_dataset_vis)-1)
In [69]:
image, mask = test_dataset_vis[random_idx]
In [70]:
visualize(
    original_image = image,
    ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask = reverse_one_hot(mask)
)
In [71]:
def crop_image(image, true_dimensions):
    """Center-crop `image` to true_dimensions[0] x true_dimensions[1] (height, width).

    Extra entries in true_dimensions (e.g. a channel count from `.shape`) are
    ignored. Returns the albumentations result dict; the cropped array is under 'image'.
    """
    target_height, target_width = true_dimensions[0], true_dimensions[1]
    center_crop = album.CenterCrop(height=target_height, width=target_width, p=1)
    return center_crop(image=image)
In [72]:
# Output folder for the side-by-side prediction PNGs.
sample_preds_folder = 'Sample-Predictions/'
# exist_ok=True keeps the cell idempotent without a separate existence check.
os.makedirs(sample_preds_folder, exist_ok=True)
In [73]:
# Inference only: eval() fixes batch-norm/dropout behaviour, and no_grad() avoids
# building an autograd graph for every frame.
best_model.eval()
polyp_channel = select_classes.index('polyp')

for idx in range(len(test_dataset)):
    image, gt_mask = test_dataset[idx]
    image_vis = test_dataset_vis[idx][0].astype('uint8')
    # Crop predictions back to the un-padded frame size.
    true_dimensions = image_vis.shape
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
    with torch.no_grad():
        pred_mask = best_model(x_tensor)
    # Drop only the batch axis (a bare squeeze() would also drop a size-1 channel
    # axis), then go channels-first -> channels-last for the numpy helpers.
    pred_mask = np.transpose(pred_mask.squeeze(0).cpu().numpy(), (1, 2, 0))
    pred_polyp_heatmap = crop_image(pred_mask[:, :, polyp_channel], true_dimensions)['image']
    pred_mask = crop_image(colour_code_segmentation(reverse_one_hot(pred_mask), select_class_rgb_values), true_dimensions)['image']
    gt_mask = np.transpose(gt_mask, (1, 2, 0))
    gt_mask = crop_image(colour_code_segmentation(reverse_one_hot(gt_mask), select_class_rgb_values), true_dimensions)['image']
    # Colour-coded masks carry the palette's platform-dependent integer dtype; cast so
    # OpenCV always writes an 8-bit PNG. [:, :, ::-1] converts RGB to OpenCV's BGR order.
    panel = np.hstack([image_vis, gt_mask, pred_mask]).astype('uint8')
    cv2.imwrite(os.path.join(sample_preds_folder, f"sample_pred_{idx}.png"), panel[:, :, ::-1])

    visualize(
        original_image = image_vis,
        ground_truth_mask = gt_mask,
        predicted_mask = pred_mask,
        pred_polyp_heatmap = pred_polyp_heatmap
    )
In [74]:
test_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)
In [75]:
valid_logs = test_epoch.run(test_dataloader)
print("Evaluation on Test Data: ")
print(f"Mean IoU Score: {valid_logs['iou_score']:.4f}")
print(f"Mean Dice Loss: {valid_logs['dice_loss']:.4f}")
valid: 100%|███████████████████████████████████| 61/61 [00:16<00:00,  3.71it/s, dice_loss - 0.2745, iou_score - 0.8742]
Evaluation on Test Data: 
Mean IoU Score: 0.8742
Mean Dice Loss: 0.2745
In [76]:
if can_train:
    plt.figure(figsize=(20,8))
    plt.plot(train_logs_df.index.tolist(), train_logs_df.iou_score.tolist(), lw=3, label = 'Train')
    plt.plot(valid_logs_df.index.tolist(), valid_logs_df.iou_score.tolist(), lw=3, label = 'Valid')
    plt.xlabel('Epochs', fontsize=21)
    plt.ylabel('IoU Score', fontsize=21)
    plt.title('IoU Score Plot', fontsize=21)
    plt.legend(loc='best', fontsize=16)
    plt.grid()
    plt.savefig('iou_score_plot.png')
    plt.show()
else:
    display(Image(url= "iou_score_plot.png"))
In [77]:
if can_train:
    plt.figure(figsize=(20,8))
    plt.plot(train_logs_df.index.tolist(), train_logs_df.dice_loss.tolist(), lw=3, label = 'Train')
    plt.plot(valid_logs_df.index.tolist(), valid_logs_df.dice_loss.tolist(), lw=3, label = 'Valid')
    plt.xlabel('Epochs', fontsize=21)
    plt.ylabel('Dice Loss', fontsize=21)
    plt.title('Dice Loss Plot', fontsize=21)
    plt.legend(loc='best', fontsize=16)
    plt.grid()
    plt.savefig('dice_loss_plot.png')
    plt.show()
else:
    display(Image(url= "dice_loss_plot.png"))
In [ ]: